{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6a5b2993",
   "metadata": {},
   "source": [
    "# Accelerating Hugging Face Llama 2 and 3 Fine-Tuning with Transformer Engine\n",
    "\n",
    "<div class=\"alert alert-info\">\n",
    "\n",
    "<b>Goal</b>\n",
    "\n",
    "This tutorial showcases how to accelerate finetuning a full [Llama 2](https://huggingface.co/meta-llama/Llama-2-7b-hf) or [Llama 3](https://huggingface.co/meta-llama/Meta-Llama-3-8B) models from Hugging Face by using `TransformerLayer` from the [Transformer Engine library](https://github.com/NVIDIA/TransformerEngine) in `BF16` and `FP8` precisions.\n",
    "\n",
    "</div>\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "331f476a",
   "metadata": {},
   "source": [
    "## Dependencies for this tutorial\n",
    "\n",
    "Following files and media are necessary to effectively run this tutorial:\n",
    "\n",
    "1. `te_llama.py`\n",
    "    - This file contains the code to load a Hugging Face Llama 2 or Llama 3 checkpoint in Transformer Engine's `TransformerLayer` instead of Hugging Face's `LlamaDecoderLayer`. This is used in the following two sections of the tutorial - \"Improvement 1\" and \"Improvement 2\".\n",
    "2. `utils.py`\n",
    "    - This file contains the code related to dataloading, hyperparameters, setting up model/optimizers/accelerator, model training and other miscellaneous tasks like restarting the jupyter notebook from within the cell. \n",
    "3. `media/`\n",
    "    - This directory contains the images used in the following tutorial.\n",
    "\n",
    "These packages are necessary to run this tutorial:\n",
    "`pytorch`, `transformer_engine`, `accelerate`, `transformers`, `peft`, `datasets`.\n",
    "\n",
    "\n",
    "<div class=\"alert alert-info\">\n",
    "\n",
    "<b>Note on running the tutorial with Llama 3 weights</b>\n",
    "\n",
    "This tutorial shows the cell outputs when run with Llama 2 7B weights. It can be run with Llama 3 8B weights simply by providing the directory with those weights (in Hugging Face format) instead of Llama 2 7B weights. These two models are almost identical, the biggest difference being the model dimension (the smallest Llama 3 model has 8B parameters, whereas the smallest Llama 2 has 7B), which enables this tutorial to work for both of them.\n",
    "\n",
    "</div>\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "44abae4f",
   "metadata": {},
   "source": [
    "## Table of contents\n",
    "1. From \"Transformer\" to \"Llama\"\n",
    "2. Hugging Face's `LlamaModel`\n",
    "    - Hugging Face's `LlamaDecoderLayer`\n",
    "3. [Baseline] Running HF `LlamaModel` (Precision: `BF16`)\n",
    "6. [Improvement 1] Replace HF's `LlamaDecoderLayer` with TE's `TransformerLayer` (Precision: `BF16`)\n",
    "    - Transformer Engine's `TransformerLayer`\n",
    "    - `TransformerLayer` options explained\n",
    "    - Mapping weights from HF's `LlamaDecoderLayer` to TE's `TransformerLayer`\n",
    "7. [Improvement 2] Replace HF's `LlamaDecoderLayer` with TE's `TransformerLayer` (Precision: `FP8`)\n",
    "8. Conclusion"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e37e2cc1",
   "metadata": {},
   "source": [
    "## From \"Transformer\" to \"Llama\" \n",
    "\n",
    "<figure align=\"center\">\n",
    "<img src=\"media/transformer_llama.png\">\n",
    "    <figcaption> Fig 1: Llama visualized as a transformer. (generated with [Nvidia's AI-foundation models](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/ai-foundation/models/sdxl))</figcaption>\n",
    "</figure>\n",
    "\n",
    "A flashback:\n",
    "\n",
    "- 2017: [\"Attention Is All You Need\"](https://arxiv.org/abs/1706.03762) paper introduced pioneering \"Transformer\" architecture and changed the NLP field forever.\n",
    "- 2018-2020: Emergence of GPT model series that showed causal decoder architectures are great fit for pretraining, few-shot and zero-shot learning.\n",
    "- Fast forward to 2023-2024: Following GPT-3/GPT-4 success stories, researchers and companies raced to produce the next best pretrained model that could further be finetuned for application-specific use-cases.\n",
    "- February 2023: Meta releases [Llama 2](https://llama.meta.com/llama2) models (Large Language Model Meta AI). \n",
    "    - These models range from 7B to 70B parameters.\n",
    "    - LLaMA 2 was pretrained on 2 trillion tokens.\n",
    "- April 2024: Meta releases [Llama 3](https://llama.meta.com/llama3) models.\n",
    "    - These models range from 8B to 70B parameters.\n",
    "    - LLaMA 3 was pretrained on 15 trillion tokens.\n",
    "\n",
    "For more information on Llama 2 consider reading the [Huggingface tutorial](https://huggingface.co/blog/llama2). As a quick summary, here are some of the important differences b/w the conventional transformer decoder architecture vs Llama 2 architecture:\n",
    "\n",
    "1. Decoder only model (causal language modeling and next word prediction)\n",
    "2. RMSNorm in place of the LayerNorm\n",
    "3. SwiGLU activation function\n",
    "4. RoPE as positional embeddings \n",
    "5. Grouped Query Attention for the 70B model\n",
    "6. Trained on 4K context length\n",
    "\n",
    "Hugging Face also released a [tutorial about Llama 3](https://huggingface.co/blog/llama3). The key points are:\n",
    "\n",
    "1. Use of bigger tokenizer - 128256 vs 32K.\n",
    "2. Grouped Query Attention is used also by smaller 8B model.\n",
    "3. The context length increased to 8K for all models.\n",
    "3. Llama 3 was trained on 8x more data than Llama 2.\n",
    "\n",
    "<figure align=\"center\">\n",
    "<img src=\"media/transformer_vs_llama.svg\">\n",
    "    <figcaption> Fig 2: Comparing GPT and Llama architectures. </figcaption>\n",
    "</figure>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a110de1a",
   "metadata": {},
   "source": [
    "## Hugging Face's `LlamaModel`\n",
    "Hugging Face provides an open-source implementation of `Llama` model in [modeling_llama.py](https://github.com/huggingface/transformers/blob/3d2900e829ab16757632f9dde891f1947cfc4be0/src/transformers/models/llama/modeling_llama.py#L4).\n",
    "\n",
    "Here's a block diagram that shows how Llama model is implemented in the Hugging Face repo. Notice the modular encapsulated form and `LlamaDecoderLayer` at the core of the model implementation.\n",
    "\n",
    "<figure align=\"center\">\n",
    "<img src=\"media/llama_for_causal_lm.svg\">\n",
    "    <figcaption> Fig 3: Causal Llama Model Block Diagram. </figcaption>\n",
    "</figure>\n",
    "\n",
    "The above diagram translates to the following text output of the model in PyTorch. Notice that the core of the model has 32 `LlamaDecoderLayer`s. \n",
    "\n",
    "```\n",
    "LlamaForCausalLM(\n",
    "  (model): LlamaModel(\n",
    "    (embed_tokens): Embedding(32000, 4096, padding_idx=0)\n",
    "    (layers): ModuleList(\n",
    "      (0-31): 32 x LlamaDecoderLayer(\n",
    "        (self_attn): LlamaFlashAttention2(\n",
    "          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
    "          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
    "          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
    "          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
    "          (rotary_emb): LlamaRotaryEmbedding()\n",
    "        )\n",
    "        (mlp): LlamaMLP(\n",
    "          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)\n",
    "          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)\n",
    "          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)\n",
    "          (act_fn): SiLU()\n",
    "        )\n",
    "        (input_layernorm): LlamaRMSNorm()\n",
    "        (post_attention_layernorm): LlamaRMSNorm()\n",
    "      )\n",
    "    )\n",
    "    (norm): LlamaRMSNorm()\n",
    "  )\n",
    "  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)\n",
    ")\n",
    "```\n",
    "\n",
    "#### Hugging Face's `LlamaDecoderLayer`\n",
    "\n",
    "Let's take a closer look at `LlamaDecoderLayer`. It is composed of `input_layernorm`, `self_attn`, `post_attention_layernorm` and `mlp` modules. Each module has associated weights as shown in the diagram.\n",
    "\n",
    "<figure align=\"center\">\n",
    "<img src=\"media/llama_zoom.svg\">\n",
    "    <figcaption> Fig 4: Causal Llama Model Block Diagram (with simplified illustration of the [LlamaDecoderLayer](https://github.com/huggingface/transformers/blob/e770f0316d2a9b787c9d1440f204fcb65e176682/src/transformers/models/llama/modeling_llama.py#L695)). </figcaption>\n",
    "</figure>\n",
    "\n",
    "##### Self_Attn Layer\n",
    "For simplicity in the block diagram illustration of the \"self_attn\" box, we omit the \"Grouped Query Attention\" operation and only showcase the modules which have associated weights.\n",
    "   \n",
    "##### MLP Layer\n",
    "\n",
    "SwiGLU is an activation defined as follows in the [modeling_llama.py](https://github.com/huggingface/transformers/blob/7c4995f93d8d24aae05e1e43279c96dce736e5c8/src/transformers/models/llama/modeling_llama.py#L236) file in the Hugging Face github repo:\n",
    "```\n",
    "\"\"\"\n",
    "1. `self.up_proj`, `self.gate_proj` and `self.down_proj` are \"Linear\" layers\n",
    "2. `self.act_fn` is a \"Swish\" function\n",
    "\n",
    "\"\"\"\n",
    "down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))\n",
    "```\n",
    "It requires a set of 3 weights as compared to 2 weights in conventional \"MLP\" layers e.g. in the traditional transformer or GPT architectures. This is also illustrated in the following figure:\n",
    "\n",
    "<figure align=\"center\">\n",
    "<img src=\"media/swiglu.svg\">\n",
    "    <figcaption> Fig 5: A look inside the feedforward layer with <code>swiglu</code> activation function. </figcaption>\n",
    "</figure>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c9529229",
   "metadata": {},
   "source": [
    "## [Baseline] Running HF `LlamaModel` (Precision: `BF16`)\n",
    "\n",
    "Llama 2 weights are loaded into the Hugging Face native implementation `LlamaForCausalLM` (refer to [modeling_llama.py](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py)). \n",
    "\n",
    "For this and other subsequent runs, the `batch_size` is `8`. The `LlamaDecoderLayer` is left unchanged in the baseline as follows:\n",
    "\n",
    "<figure align=\"center\">\n",
    "<img src=\"media/llamadecoderlayer.svg\">\n",
    "    <figcaption> Fig 6: Revisiting \"LlamaDecoderLayer\". </figcaption>\n",
    "</figure>\n",
    "\n",
    "<div class=\"alert alert-info\">\n",
    "<b>Note</b>\n",
    "\n",
    "The baseline implementation will be run in `BF16` precision.\n",
    "\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b38eb3ac",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-info\">\n",
    "\n",
    "<b>Note</b>\n",
    "    \n",
    "This tutorial loads and trains a Llama 3 8B or a Llama 2 7B model which takes up most of the GPU memory and therefore, we need to restart the jupyter notebook each time before running the following sections. A small utility method `restart_jupyter_notebook` is defined in the accompanying `utils.py` file. This function restarts the jupyter notebook so that the GPU memory is flushed before the model is loaded again from the checkpoint in order to avoid running into OOM (Out Of Memory) errors.\n",
    "\n",
    "If the utility doesn't work, comment this line `restart_jupyter_notebook()` in the following cell and manually restart the jupyter notebook before running the cell. Repeat the same for other sections in this tutorial.\n",
    "\n",
    "</div>\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2e9d7a8c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10 finetuning steps complete!\n",
      "Average time taken per step: 248 milliseconds\n"
     ]
    }
   ],
   "source": [
    "# Restart the notebook (to flush the GPU memory)\n",
    "from utils import restart_jupyter_notebook\n",
    "restart_jupyter_notebook()\n",
    "\n",
    "\n",
    "# Import necessary packages, methods and variables\n",
    "from utils import *\n",
    "\n",
    "\n",
    "# Provide Huggingface Access Token\n",
    "hyperparams.hf_access_token = \"\"\n",
    "assert hyperparams.hf_access_token, \"Provide a HF API Access Token!\"\n",
    "\n",
    "# Provide a directory to cache weights in to avoid downloading them every time.\n",
    "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
    "hyperparams.weights_cache_dir = \"\"\n",
    "\n",
    "# For Llama 2, uncomment this line (also set by default)\n",
    "hyperparams.model_name = \"meta-llama/Llama-2-7b-hf\"\n",
    "\n",
    "# For Llama 3, uncomment this line\n",
    "# hyperparams.model_name = \"meta-llama/Meta-Llama-3-8B\"\n",
    "\n",
    "hyperparams.mixed_precision = \"bf16\"\n",
    "\n",
    "\n",
    "# Init the model and accelerator wrapper\n",
    "model = init_baseline_model(hyperparams)\n",
    "accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n",
    "\n",
    "\n",
    "# Finetune the model\n",
    "finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4035ccb7",
   "metadata": {},
   "source": [
    "Let's add this information in a table and keep comparing it with a few possible improvements in future sections:\n",
    "\n",
    "| Models                                                      | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n",
    "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n",
    "| HF (baseline)                                       | BF16      | 248                         | 1                       |"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3db90dff",
   "metadata": {},
   "source": [
    "## [Improvement 1] Replace HF's `LlamaDecoderLayer` with TE's `TransformerLayer` (Precision: `BF16`)\n",
    "\n",
    "In addition to basic layers like `Linear` and `LayerNorm`, Transformer Engine offers larger modules like `MultiheadAttention` (combines \"LayerNorm\" and \"Self Attention\") and `LayerNormMLP` (combines \"LayerNorm\" and \"MLP\") that could replace their counterparts in the `LlamaDecoderLayer` and potentially provide a speedup. Transformer Engine also offers a full `TransformerLayer` (which further combines `MultiheadAttention` and `LayerNormMLP` layers) which could replace `LlamaDecoderLayer` and provide a speedup (with careful mapping of the weights since the name of the weights are different for those two layers). Let's take a closer look at Transformer Engine's `TransformerLayer`. \n",
    "\n",
    "#### Transformer Engine's `TransformerLayer`\n",
    "\n",
    "At a higher level, TE's `TransformerLayer` could be visualized as an apt replacement for the `LlamaDecoderLayer`. But the internals of the `TransformerLayer` are organized a bit differently. \n",
    "\n",
    "<figure align=\"center\">\n",
    "<img src=\"media/tellamadecoderlayer.svg\">\n",
    "    <figcaption> Fig 7: Transformer Engine's `TransformerLayer` </figcaption>\n",
    "</figure>\n",
    "\n",
    "Just like Hugging Face's `LlamaDecoderLayer`, Transformer Engine's `TransformerLayer` encapsulates `self_attention` (as `MultiheadAttention`) and `mlp` (as `LayerNormMLP`). A major difference is that the two `Norm`s are included in the `MultiheadAttention` and `LayerNormMLP` layers as shown in the following output prompt:\n",
    "\n",
    "```\n",
    "TransformerLayer(\n",
    "    (self_attention): MultiheadAttention(\n",
    "      (layernorm_qkv): LayerNormLinear()\n",
    "      (core_attention): DotProductAttention()\n",
    "      (proj): Linear()\n",
    "    )\n",
    "    (layernorm_mlp): LayerNormMLP()\n",
    ")\n",
    "```\n",
    "\n",
    "Another difference is that Transformer Engine implements an efficient version of feedforward layer with SwiGLU in which the weights from the `up_proj` and `gate_proj` modules are merged together and SwiGLU is applied using a custom fused kernel. This is done so that only one big and efficient Matrix Multiplication operation is issued to the GPU instead of two smaller ones.\n",
    "\n",
    "<figure align=\"center\">\n",
    "<img src=\"media/swiglu_te.svg\">\n",
    "    <figcaption> Fig 8: Abstract illustration of the SwiGLU implementation in Transformer Engine. </figcaption>\n",
    "</figure>\n",
    "\n",
    "#### `TransformerLayer` options explained\n",
    "\n",
    "<div class=\"alert alert-info\">\n",
    "\n",
    "<b>Note</b>\n",
    "    \n",
    "Here, we go over some of the options in `TransformerLayer` that are needed for the tutorial. For a complete list of options, refer the [TransformerLayer API documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/pytorch.html?highlight=transformerlayer#transformer_engine.pytorch.TransformerLayer).\n",
    "\n",
    "</div>\n",
    "\n",
    "In the accompanying `te_llama.py` file, `TELlamaDecoderLayer` is defined as a wrapper over TE's `TransformerLayer` with a few needed options that make `TransformerLayer` a plug-in replacement for the HF's `LlamaDecoderLayer`.\n",
    "\n",
    "```\n",
    "class TELlamaDecoderLayer(te.pytorch.TransformerLayer):\n",
    "    def __init__(self, config):\n",
    "        super().__init__(\n",
    "            config.hidden_size,\n",
    "            config.intermediate_size,\n",
    "            config.num_attention_heads,\n",
    "            bias=False,\n",
    "            layernorm_epsilon=config.rms_norm_eps,\n",
    "            hidden_dropout=0,\n",
    "            attention_dropout=0,\n",
    "            fuse_qkv_params=False,\n",
    "            normalization=\"RMSNorm\",\n",
    "            activation=\"swiglu\",\n",
    "            attn_input_format=\"bshd\",\n",
    "            num_gqa_groups=config.num_key_value_heads,\n",
    "        )\n",
    "        te_rope = RotaryPositionEmbedding(config.hidden_size//config.num_attention_heads)\n",
    "        self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda()\n",
    "```\n",
    "\n",
    "Here's a list summarizing each option briefly:\n",
    "\n",
    "1. `hidden_size`: size of each input sample.\n",
    "2. `ffn_hidden_size`: intermediate size to which samples are projected.\n",
    "3. `num_attention_heads`: number of attention heads in the transformer layer.\n",
    "4. `bias`: switch to add additive biases to the submodule layers.\n",
    "5. `layernorm_epsilon`: a value added to the denominator of layer normalization for numerical stability. Default is `1e-5`.\n",
    "6. `hidden_dropout`: dropout probability for the dropout op after FC2 layer (fully connected layer no. 2). Default is `0.1`.\n",
    "7. `attention_dropout`: dropout probability for the dropout op during multi-head attention. Default is `0.1`. \n",
    "8. `fuse_qkv_params`:  if set to True, TransformerLayer module exposes a single fused parameter for query-key-value. This enables optimizations such as QKV fusion without concatentations/splits and also enables the argument fuse_wgrad_accumulation.\n",
    "9. `normalization`: type of normalization applied. Default is `LayerNorm`.\n",
    "10. `activation`: type of activation used in the MLP block. Default is `gelu`.\n",
    "11. `attn_input_format`: controls whether the dimensions of the intermediate hidden states is 'batch first' ('bshd') or 'sequence first' ('sbhd'). `s` stands for the sequence length, `b` batch size, `h` the number of heads, `d` head size. Note that these formats are very closely related to the `qkv_format` in the `MultiHeadAttention` and `DotProductAttention` modules.\n",
    "12. `num_gqa_groups`: number of GQA groups in the transformer layer. Grouped Query Attention is described in [this paper](https://arxiv.org/pdf/2305.13245.pdf). This only affects the keys and values, not the querys. GQA-1 is equivalent to Multi-Query Attention ([MQA](https://arxiv.org/pdf/1911.02150.pdf)), while GQA-H is equivalent to MultiHead Attention, i.e. `num_gqa_groups = num_attention_heads`.\n",
    "\n",
    "\n",
    "Further, note that `RotaryPositionEmbedding` is defined as part of the `TELlamaDecoderLayer` (wrapper around TE's `TransformerLayer`) itself since it expects this rope cache if RoPE is used in the model. \n",
    "\n",
    "Let's revisit how `LlamaDecoderLayer`s form the core of the decoder layer stack in HF's llama implementation:\n",
    "```\n",
    "ModuleList(\n",
    "  (0-31): 32 x LlamaDecoderLayer(\n",
    "    (self_attn): LlamaAttention(\n",
    "      (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
    "      (k_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
    "      (v_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
    "      (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
    "      (rotary_emb): LlamaRotaryEmbedding()\n",
    "    )\n",
    "    (mlp): LlamaMLP(\n",
    "      (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)\n",
    "      (up_proj): Linear(in_features=4096, out_features=11008, bias=False)\n",
    "      (down_proj): Linear(in_features=11008, out_features=4096, bias=False)\n",
    "      (act_fn): SiLU()\n",
    "    )\n",
    "    (input_layernorm): LlamaRMSNorm()\n",
    "    (post_attention_layernorm): LlamaRMSNorm()\n",
    "  )\n",
    ")\n",
    "```\n",
    "\n",
    "A major portion of the Hugging Face model implementation (32 `LlamaDecoderLayer` layers) could be potentially replaced with Transformer Engine's `TransformerLayer` layers. Let's see how it is made possible.\n",
    "\n",
    "\n",
    "#### Mapping weights from HF's `LlamaDecoderLayer` to TE's `TransformerLayer`\n",
    "\n",
    "Refer the accompanying file `te_llama.py` which provides a reference to create a Llama 2 model with TE's `TransformerLayer` after replacing HF's `LlamaDecoderLayer`.\n",
    "\n",
    "Briefly, following pieces of code are put together:\n",
    "\n",
    "1. `TELlamaDecoderLayer` is added as a wrapper for `TransformerLayer`. \n",
    "```\n",
    "class TELlamaDecoderLayer(te.pytorch.TransformerLayer):\n",
    "    \"\"\"\n",
    "    Wrapper class over TE's `TransformerLayer`. This makes the wrapper very\n",
    "    similar to HF's `LlamaDecoderLayer` and easier to replace it in the code.\n",
    "\n",
    "    Args:\n",
    "        config: LlamaConfig\n",
    "        args: positional args (for compatibility with `LlamaDecoderLayer`)\n",
    "        kwargs: keyword args (for compatibility with `LlamaDecoderLayer`)\n",
    "    \"\"\"\n",
    "    def __init__(self, config, *args, **kwargs):\n",
    "        super().__init__(\n",
    "            hidden_size=config.hidden_size,\n",
    "            ffn_hidden_size=config.intermediate_size,\n",
    "            num_attention_heads=config.num_attention_heads,\n",
    "            bias=False,\n",
    "            layernorm_epsilon=config.rms_norm_eps,\n",
    "            hidden_dropout=0,\n",
    "            attention_dropout=0,\n",
    "            fuse_qkv_params=False,\n",
    "            normalization=\"RMSNorm\",\n",
    "            activation=\"swiglu\",\n",
    "            attn_input_format=\"bshd\",\n",
    "        )\n",
    "        te_rope = RotaryPositionEmbedding(config.hidden_size//config.num_attention_heads)\n",
    "        self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda()\n",
    "\n",
    "    def forward(self,\n",
    "                hidden_states,\n",
    "                *args,\n",
    "                attention_mask,\n",
    "                **kwargs):\n",
    "        \"\"\"\n",
    "        Custom forward to make sure we only pass relevant arguments to the\n",
    "        forward pass of the `TransformerLayer`. Also, make sure the output\n",
    "        format matches the output of the HF's `LlamaDecoderLayer`.\n",
    "        \"\"\"\n",
    "        return (super().forward(hidden_states, attention_mask=attention_mask, rotary_pos_emb=self.te_rope_emb),)\n",
    "```\n",
    "\n",
    "2. Before creating a `LlamaForCausalLM`, `replace_decoder` context manager is used to monkey-patch `LlamaDecoderLayer` with `TELlamaDecoderLayer`.\n",
    "\n",
    "```\n",
    "@contextmanager\n",
    "def replace_decoder(te_decoder_cls):\n",
    "    \"\"\"\n",
    "    Replace `LlamaDecoderLayer` with custom `TELlamaDecoderLayer`.\n",
    "    \"\"\"\n",
    "    original_llama_decoder_cls = transformers.models.llama.modeling_llama.LlamaDecoderLayer\n",
    "    transformers.models.llama.modeling_llama.LlamaDecoderLayer = te_decoder_cls\n",
    "    try:\n",
    "        yield\n",
    "    finally:\n",
    "        transformers.models.llama.modeling_llama.LlamaDecoderLayer = original_llama_decoder_cls\n",
    ".\n",
    ".\n",
    ".\n",
    "class TELlamaForCausalLM:\n",
    "    \"\"\"\n",
    "    Causal LM created with `LlamaModel`. The underlying `LlamaDecoderLayer`\n",
    "    class is monkey-patched with `TELlamaDecoderLayer` class before\n",
    "    initializing the causal LM with `LlamaForCausalLM`.\n",
    "\n",
    "    Args:\n",
    "        config: LlamaConfig\n",
    "    \"\"\"\n",
    "\n",
    "    def __new__(cls, config: LlamaConfig):\n",
    "        with replace_decoder(te_decoder_cls=TELlamaDecoderLayer):\n",
    "            llama_for_causal_lm = LlamaForCausalLM(config)\n",
    "        return llama_for_causal_lm\n",
    ".\n",
    ".\n",
    ".\n",
    "```\n",
    "\n",
    "3. A custom `pretrained_from_local` method is added that copies the weights from the checkpoint (which is meant for HF Llama implementation) to the modified `TELlamaForCausalLM` by carefully mapping the weights from the `LlamaDecoderLayer` (HF) to `TransformerLayer` (TE). The method `replace_params` maps and copies apt weights from `LlamaDecoderLayer` to the `TransformerLayer`. Refer to the following diagram for more details.\n",
    "\n",
    "```\n",
    "def replace_params(hf_state_dict, te_state_dict):\n",
    "    # collect all layer prefixes to update\n",
    "    all_layer_prefixes = set()\n",
    "    for param_key in hf_state_dict.keys():\n",
    "        layer_prefix_pat = 'model.layers.\\d+.'\n",
    "        m = re.match(layer_prefix_pat, param_key)\n",
    "        if m is not None:\n",
    "            all_layer_prefixes.add(m.group())\n",
    "\n",
    "    for layer_prefix in all_layer_prefixes:\n",
    "        # When loading weights into models with less number of layers, skip the\n",
    "        # copy if the corresponding layer doesn't exist in TE model\n",
    "        if layer_prefix + 'self_attention.layernorm_qkv.layer_norm_weight' in te_state_dict:\n",
    "            te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.layer_norm_weight'].data[:] = hf_state_dict[layer_prefix + 'input_layernorm.weight'].data[:]\n",
    "\n",
    "        if layer_prefix + 'self_attention.layernorm_qkv.query_weight' in te_state_dict:\n",
    "            te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.query_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.q_proj.weight'].data[:]\n",
    "\n",
    "        if layer_prefix + 'self_attention.layernorm_qkv.key_weight' in te_state_dict:\n",
    "            te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.key_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.k_proj.weight'].data[:]\n",
    "    .\n",
    "    .\n",
    "    .\n",
    "\n",
    "    return all_layer_prefixes\n",
    "```\n",
    "\n",
    "The following figure shows how the weights get mapped from the HF's `LlamaDecoderLayer` to TE's `TransformerLayer`.\n",
    "\n",
    "<figure align=\"center\">\n",
    "<img src=\"media/weight_swap.svg\">\n",
    "    <figcaption> Fig 9: Replace `LlamaDecoderLayer` with `TransformerLayer`. </figcaption>\n",
    "</figure>\n",
    "\n",
    "After initializing the modified Llama model this way, the core decoder layers get changed to `TELlamaDecoderLayer` (wrapper around `TransformerLayer`) as shown in the following output:\n",
    "```\n",
    "ModuleList(\n",
    "  (0-31): 32 x TELlamaDecoderLayer(\n",
    "    (self_attention): MultiheadAttention(\n",
    "      (layernorm_qkv): LayerNormLinear()\n",
    "      (core_attention): DotProductAttention(\n",
    "        (flash_attention): FlashAttention()\n",
    "        (fused_attention): FusedAttention()\n",
    "        (unfused_attention): UnfusedDotProductAttention(\n",
    "          (scale_mask_softmax): FusedScaleMaskSoftmax()\n",
    "          (attention_dropout): Dropout(p=0, inplace=False)\n",
    "        )\n",
    "      )\n",
    "      (proj): Linear()\n",
    "    )\n",
    "    (layernorm_mlp): LayerNormMLP()\n",
    "  )\n",
    ")\n",
    "```\n",
    "\n",
    "In summary, the model gets changed as follows with a large chunk of the implementation (core decoder layers) coming from Transformer Engine.\n",
    "\n",
    "<figure align=\"center\">\n",
    "<img src=\"media/model_change.svg\">\n",
    "    <figcaption> Fig 10: Language model after the HF's `LlamaDecoderLayer`s are replaced with TE's `TransformerLayer`s. </figcaption>\n",
    "</figure>\n",
    "\n",
    "\n",
    "<div class=\"alert alert-info\">\n",
    "<b>Note</b>\n",
    "\n",
    "Let's first run this \"TELlama\" implementation in `BF16` precision.\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "bdb34b91",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10 finetuning steps complete!\n",
      "Average time taken per step: 185 milliseconds\n"
     ]
    }
   ],
   "source": [
    "# Restart the notebook (to flush the GPU memory)\n",
    "from utils import restart_jupyter_notebook\n",
    "restart_jupyter_notebook()\n",
    "\n",
    "\n",
    "# Import necessary packages, methods and variables\n",
    "from utils import *\n",
    "\n",
    "\n",
    "# Provide Huggingface Access Token\n",
    "hyperparams.hf_access_token = \"\"\n",
    "assert hyperparams.hf_access_token, \"Provide a HF API Access Token!\"\n",
    "\n",
    "# Provide a directory to cache weights in to avoid downloading them every time.\n",
    "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
    "hyperparams.weights_cache_dir = \"\"\n",
    "\n",
    "# For Llama 2, uncomment this line (also set by default)\n",
    "hyperparams.model_name = \"meta-llama/Llama-2-7b-hf\"\n",
    "\n",
    "# For Llama 3, uncomment this line\n",
    "# hyperparams.model_name = \"meta-llama/Meta-Llama-3-8B\"\n",
    "\n",
    "hyperparams.mixed_precision = \"bf16\"\n",
    "\n",
    "\n",
    "# Init the model and accelerator wrapper\n",
    "model = init_te_llama_model(hyperparams)\n",
    "accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n",
    "\n",
    "\n",
    "# Finetune the model\n",
    "finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0c9fbd65",
   "metadata": {},
   "source": [
    "Compared to the \"baseline\" implementation, we see that using Transformer Engine's `TransformerLayer` in place of Huggging Face's `LlamaDecoderLayer` gives a speedup of **34%** even when using only BF16 precision!\n",
    "\n",
    "| Models                                                      | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n",
    "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n",
    "| HF (baseline)                                               | BF16      | 248                         | 1                       |\n",
    "| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | BF16      | 185                         | 1.34                    |"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "98cd8efb",
   "metadata": {},
   "source": [
    "## [Improvement 2] Replace HF's `LlamaDecoderLayer` with TE's `TransformerLayer` (Precision: `FP8`)\n",
    "\n",
    "Now that most of the HF Llama model implementation (`LlamaDecoderLayer`s) has been swapped with Transformer Engine implementation (`TELlamaDecoderLayer` or `TransformerLayer`), let's see how finetuning in `FP8` precision helps improve performance.\n",
    "\n",
    "#### How to run the model in `FP8` precision\n",
    "\n",
    "After the substitution, the model can be run in `FP8` precision by the following change over the previous BF16 runs. (For more information, refer the corresponding `wrap_with_accelerator` function in the accompanying `utils.py` file).\n",
    "\n",
    "```\n",
    "# Specify the `FP8RecipeKwargs` (additional argument required to run in `fp8` precision)\n",
    "fp8_kwarg_handler = [FP8RecipeKwargs(backend=\"te\")]\n",
    "\n",
    "# Pass the `FP8RecipeKwargs` to the `Accelerator` init call\n",
    "accelerator = Accelerator(\n",
    "    ...\n",
    "    kwargs_handlers=fp8_kwarg_handler\n",
    ")\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "772c6f22",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10 finetuning steps complete!\n",
      "Average time taken per step: 160 milliseconds\n"
     ]
    }
   ],
   "source": [
    "# Restart the notebook (to flush the GPU memory)\n",
    "from utils import restart_jupyter_notebook\n",
    "restart_jupyter_notebook()\n",
    "\n",
    "\n",
    "# Import necessary packages, methods and variables\n",
    "from utils import *\n",
    "\n",
    "\n",
    "# Provide Huggingface Access Token\n",
    "hyperparams.hf_access_token = \"\"\n",
    "assert hyperparams.hf_access_token, \"Provide a HF API Access Token!\"\n",
    "\n",
    "# Provide a directory to cache weights in to avoid downloading them every time.\n",
    "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
    "hyperparams.weights_cache_dir = \"\"\n",
    "\n",
    "# For Llama 2, uncomment this line (also set by default)\n",
    "hyperparams.model_name = \"meta-llama/Llama-2-7b-hf\"\n",
    "\n",
    "# For Llama 3, uncomment this line\n",
    "# hyperparams.model_name = \"meta-llama/Meta-Llama-3-8B\"\n",
    "\n",
    "hyperparams.mixed_precision = \"fp8\"\n",
    "\n",
    "\n",
    "# Init the model and accelerator wrapper\n",
    "model = init_te_llama_model(hyperparams)\n",
    "accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n",
    "\n",
    "\n",
    "# Finetune the model\n",
    "finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7cf9c3a",
   "metadata": {},
   "source": [
    "| Models                                                      | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n",
    "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n",
    "| HF (baseline)                                               | BF16      | 248                         | 1                       |\n",
    "| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | BF16      | 185                         | 1.34                    |\n",
    "| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | FP8       | 160                         | 1.55                    |\n",
    "\n",
    "\n",
    "After turning on FP8 precision, we get even more speedup of **55%** (with Llama 2 7B)!\n",
    "\n",
    "#### Llama 3 performance results\n",
    "Running the same tutorial with **Llama 3 8B** yields the following performance numbers:\n",
    "\n",
    "| Models                                                      | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n",
    "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n",
    "| HF (baseline)                                               | BF16      | 270                         | 1                       |\n",
    "| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | BF16      | 217                         | 1.24                    |\n",
    "| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | FP8       | 185                         | 1.46                    |\n",
    "\n",
    "For Llama 3 8B, we get the most speedup of **46%** with FP8 precision!\n",
    "\n"
   ]
  },
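  {
   "cell_type": "markdown",
   "id": "c41e7a2d",
   "metadata": {},
   "source": [
    "The speedup column in the tables above is the baseline step time divided by the step time of each configuration, so a value of 1.34 corresponds to the 34% speedup quoted in the text. The snippet below is a minimal sketch that recomputes those values from the measured step times. The names `step_times_ms` and `speedups` are illustrative only and are not part of `utils.py` or `te_llama.py`.\n",
    "\n",
    "```python\n",
    "# Step times in milliseconds, copied from the tables above.\n",
    "step_times_ms = {\n",
    "    'Llama 2 7B': {'HF BF16 (baseline)': 248, 'TE BF16': 185, 'TE FP8': 160},\n",
    "    'Llama 3 8B': {'HF BF16 (baseline)': 270, 'TE BF16': 217, 'TE FP8': 185},\n",
    "}\n",
    "\n",
    "\n",
    "def speedups(times_ms, baseline='HF BF16 (baseline)'):\n",
    "    # Hypothetical helper: speedup = baseline step time / configuration step time,\n",
    "    # rounded to two decimals as in the tables.\n",
    "    base = times_ms[baseline]\n",
    "    return {name: round(base / t, 2) for name, t in times_ms.items()}\n",
    "\n",
    "\n",
    "for model_name, times_ms in step_times_ms.items():\n",
    "    for config, speedup in speedups(times_ms).items():\n",
    "        print(f'{model_name} | {config}: {speedup:.2f}x')\n",
    "```\n",
    "\n",
    "Step times depend on the GPU, batch size and sequence length, so to compare your own runs, replace the values in `step_times_ms` with the averages printed by the cells above."
   ]
  },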
  {
   "cell_type": "markdown",
   "id": "95d6c42b",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "\n",
    "Using `TransformerLayer` module from Transformer Engine as a substitute for Hugging Face's `LlamaDecoderLayer` provides a speedup over Hugging Face's native Llama 2 and Llama 3 implementations. This needs careful initialization of the model such that the model weights (which are meant for `LlamaDecoderLayer`) are correctly mapped to their counterparts in TE's `TransformerLayer`. Even with `BF16` precision, `TransformerLayer` provides a speedup over the baseline implementation. With `FP8` precision, the speed up is even more pronounced!"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
